{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "initial_id",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-18T10:42:12.621610Z",
     "start_time": "2024-11-18T10:42:11.959958Z"
    },
    "collapsed": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import asyncio\n",
    "import os\n",
    "import time\n",
    "from dataclasses import dataclass, field\n",
    "from typing import List, Tuple\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "from utils import generate_arithmetic_expression, re_parse_json, calculate_time_difference\n",
    "\n",
    "# 加载 .env 文件\n",
    "load_dotenv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "aaa90b21154473a3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-18T10:42:14.504861Z",
     "start_time": "2024-11-18T10:42:14.501701Z"
    }
   },
   "outputs": [],
   "source": [
    "# API_KEY is expected to hold one or more comma-separated API keys.\n",
    "# os.getenv returns None when the variable is unset, so default to \"\" and fail\n",
    "# with an explicit message instead of an AttributeError raised by None.split().\n",
    "# Surrounding whitespace is stripped and empty entries (e.g. from \"k1, k2,\") are dropped.\n",
    "api_keys = [key.strip() for key in os.getenv('API_KEY', '').split(\",\") if key.strip()]\n",
    "if not api_keys:\n",
    "    raise RuntimeError(\"API_KEY is not set or empty; expected comma-separated API keys\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4a6420fb448312ab",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-18T10:42:16.036926Z",
     "start_time": "2024-11-18T10:42:16.033125Z"
    }
   },
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class LLMAPI:\n",
    "    \"\"\"One chat-model client bound to a single API key.\n",
    "\n",
    "    Attributes:\n",
    "        api_key: Key handed to the ChatOpenAI client.\n",
    "        uid: Integer identifier supplied by the caller.\n",
    "        cnt: Number of times ``agenerate`` has been invoked on this instance.\n",
    "        llm: ChatOpenAI client; built in ``__post_init__``, not a constructor argument.\n",
    "    \"\"\"\n",
    "    api_key: str\n",
    "    uid: int\n",
    "    cnt: int = 0\n",
    "    llm: ChatOpenAI = field(init=False)  # excluded from __init__; assigned in __post_init__\n",
    "\n",
    "    def __post_init__(self):\n",
    "        \"\"\"Build the client once the dataclass fields have been assigned.\"\"\"\n",
    "        self.llm = self.create_llm()\n",
    "\n",
    "    def create_llm(self):\n",
    "        \"\"\"Return a new ChatOpenAI client configured with this instance's API key.\"\"\"\n",
    "        client_options = {\n",
    "            \"model\": \"gpt-4o-mini\",\n",
    "            \"base_url\": \"https://api.chatfire.cn/v1/\",\n",
    "            \"api_key\": self.api_key,\n",
    "        }\n",
    "        return ChatOpenAI(**client_options)\n",
    "\n",
    "    async def agenerate(self, text):\n",
    "        \"\"\"Increment the call counter, then await ``self.llm.agenerate([text])``.\n",
    "\n",
    "        The awaited result is returned unchanged.\n",
    "        \"\"\"\n",
    "        self.cnt += 1\n",
    "        return await self.llm.agenerate([text])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbeb681aefa924da",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f492ae28cc3444a4",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-18T10:42:18.840185Z",
     "start_time": "2024-11-18T10:42:18.837656Z"
    }
   },
   "outputs": [],
   "source": [
    "async def call_llm(llm: LLMAPI, text: str):\n",
    "    \"\"\"Send ``text`` through ``llm`` and return the awaited ``llm.agenerate`` result.\"\"\"\n",
    "    response = await llm.agenerate(text)\n",
    "    return response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fe8352b4d44ea795",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-18T10:42:20.436236Z",
     "start_time": "2024-11-18T10:42:20.426036Z"
    }
   },
   "outputs": [],
   "source": [
    "async def run_api(keys: List[str], data: List[str]) -> Tuple[list, List[LLMAPI]]:\n",
    "    \"\"\"Send every prompt in ``data`` concurrently, spread round-robin over ``keys``.\n",
    "\n",
    "    Args:\n",
    "        keys: API keys; one LLMAPI client is created per key.\n",
    "        data: Prompts to send; prompt ``i`` goes to client ``i % len(keys)``.\n",
    "\n",
    "    Returns:\n",
    "        ``(results, llms)``: the awaited ``call_llm`` results in the same order as\n",
    "        ``data``, and the LLMAPI clients that were created (each client's ``cnt``\n",
    "        holds how many prompts it was given).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``data`` is non-empty but ``keys`` is empty.\n",
    "    \"\"\"\n",
    "    llms = [LLMAPI(api_key=key, uid=i) for i, key in enumerate(keys)]\n",
    "    if data and not llms:\n",
    "        # Without this guard the round-robin index below fails with ZeroDivisionError.\n",
    "        raise ValueError(\"run_api needs at least one API key to send requests\")\n",
    "\n",
    "    # Create all coroutines first, then await them together; asyncio.gather\n",
    "    # returns the results in the order the awaitables were passed in.\n",
    "    pending = [call_llm(llms[i % len(llms)], text) for i, text in enumerate(data)]\n",
    "    results = await asyncio.gather(*pending)\n",
    "    return results, llms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 142,
   "id": "ced43ba9503a976a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-17T13:25:02.685951Z",
     "start_time": "2024-11-17T13:24:53.751295Z"
    }
   },
   "outputs": [],
   "source": [
    "# results, llms = await run_api(api_keys, questions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5dfe761deaf4cf7a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-17T13:26:01.559087Z",
     "start_time": "2024-11-17T13:26:01.556473Z"
    }
   },
   "outputs": [],
   "source": [
    "# for item in results:\n",
    "#     print(item.generations[0][0].text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a4288b7eac184ce",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "a8536b583771415d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-18T10:42:26.118975Z",
     "start_time": "2024-11-18T10:42:26.114935Z"
    }
   },
   "outputs": [],
   "source": [
    "# Prompt sent to the model; {question} is replaced with one arithmetic expression.\n",
    "prompt_template = \"\"\"\n",
    "    请将以下表达式的计算结果返回为 JSON 格式：\n",
    "    {{\n",
    "      \"expression\": \"{question}\",\n",
    "      \"result\": ?\n",
    "    }}\n",
    "    \"\"\"\n",
    "\n",
    "# Each call to generate_arithmetic_expression(4) is unpacked as (expression, label).\n",
    "generated = [generate_arithmetic_expression(4) for _ in range(90)]\n",
    "\n",
    "# Formatted prompts and their labels, kept in the same order.\n",
    "questions = [prompt_template.format(question=expression) for expression, _ in generated]\n",
    "labels = [label for _, label in generated]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "de9cca3d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "90"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: display how many labels were generated.\n",
    "len(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "928f0779adbc3eb6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy: 90.0%\n",
      "executed in 00:00:13.345 (h:m:s.ms)\n"
     ]
    }
   ],
   "source": [
    "start_time = time.time()\n",
    "\n",
    "# Top-level await is used here because the notebook kernel already runs an event loop.\n",
    "# In a plain script, use asyncio.run(run_api(api_keys, questions)) instead.\n",
    "results, llms = await run_api(api_keys, questions)\n",
    "\n",
    "right = 0       # parsed answer equals the label\n",
    "except_cnt = 0  # reply had no usable integer result (missing, unparseable, or raised)\n",
    "not_equal = 0   # parsed answer differs from the label\n",
    "\n",
    "for q, llm_result, label in zip(questions, results, labels):\n",
    "    # `res` starts as the raw reply text and is narrowed step by step, so the\n",
    "    # failure message below shows the last value that was successfully produced.\n",
    "    res = llm_result.generations[0][0].text\n",
    "    try:\n",
    "        res = re_parse_json(res)\n",
    "        if res is None:\n",
    "            except_cnt += 1\n",
    "            continue\n",
    "\n",
    "        res = res.get(\"result\", None)\n",
    "        if res is None:\n",
    "            except_cnt += 1\n",
    "            continue\n",
    "\n",
    "        res = int(res)\n",
    "    except Exception as e:\n",
    "        # Count the failure as well as reporting it, so that\n",
    "        # right + except_cnt + not_equal always equals the number of replies.\n",
    "        except_cnt += 1\n",
    "        print(e)\n",
    "        print(f\"question:{q}\\nresult:{res}\")\n",
    "    else:\n",
    "        # Only reached when parsing succeeded (the `continue`s above skip it).\n",
    "        if res == label:\n",
    "            right += 1\n",
    "        else:\n",
    "            not_equal += 1\n",
    "\n",
    "# Guard against an empty question list to avoid ZeroDivisionError.\n",
    "accuracy = right / len(questions) * 100 if questions else 0.0\n",
    "print(\"accuracy: {}%\".format(accuracy))\n",
    "end_time = time.time()\n",
    "calculate_time_difference(start_time, end_time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "58cd45af",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(81, 0, 9)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluation tallies: answers equal to the label, replies without a usable result, answers that differ.\n",
    "right, except_cnt, not_equal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "4ba0f67b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.43333333333333335"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): ad-hoc scratch calculation; what 13 and 30 stand for is not stated - confirm or remove.\n",
    "13 / 30"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57f28878",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
